{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d346514d",
   "metadata": {},
   "source": [
    "### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "405c8dc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import warnings\n",
    "from tqdm import tqdm\n",
    "from torch.autograd import Variable\n",
    "from sklearn.metrics import mean_absolute_error"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1077d51d",
   "metadata": {},
   "source": [
    "### Auglichem imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a60964fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from auglichem.crystal import (PerturbStructureTransformation,\n",
    "                               RotationTransformation,\n",
    "                               SwapAxesTransformation,\n",
    "                               TranslateSitesTransformation,\n",
    "                               SupercellTransformation,\n",
    ")\n",
    "from auglichem.crystal.data import CrystalDatasetWrapper\n",
    "from auglichem.crystal.models import CrystalGraphConvNet as CGCNN"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2abb209",
   "metadata": {},
   "source": [
    "### Set up dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1d49da9f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data found at: ./data_download/lanths\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████| 3332/3332 [00:00<00:00, 55302.86it/s]\n"
     ]
    }
   ],
   "source": [
    "# Create transformation\n",
    "transforms = [\n",
    "        PerturbStructureTransformation(distance=0.1, min_distance=0.01),\n",
    "        RotationTransformation(axis=[0,0,1], angle=90),\n",
    "        SwapAxesTransformation(),\n",
    "        TranslateSitesTransformation(indices_to_move=[0], translation_vector=[1,0,0],\n",
    "                                     vector_in_frac_coords=True),\n",
    "        SupercellTransformation(scaling_matrix=[[1,0,0],[0,1,0],[0,0,1]]),\n",
    "]\n",
    "\n",
    "# Initialize dataset object\n",
    "dataset = CrystalDatasetWrapper(\"lanthanides\", batch_size=256,\n",
    "                                valid_size=0.1, test_size=0.1, cgcnn=True)\n",
    "\n",
    "# Get train/valid/test splits as loaders\n",
    "train_loader, valid_loader, test_loader = dataset.get_data_loaders(transform=transforms)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "04580c5c",
   "metadata": {},
   "source": [
    "### Initialize model with task from data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "db898fdd",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/clo/miniforge3/envs/auglichem/lib/python3.8/site-packages/pymatgen/io/cif.py:1165: UserWarning: Issues encountered while parsing CIF: Some fractional co-ordinates rounded to ideal values to avoid issues with finite precision.\n",
      "  warnings.warn(\"Issues encountered while parsing CIF: %s\" % \"\\n\".join(self.warnings))\n",
      "/var/folders/mh/wfzfv8nd3g7_8w30_pjfbhtr0000gn/T/ipykernel_59279/3365660636.py:6: RuntimeWarning: CrystalDataset._cgcnn must be set to True to use CGCNN properly.\n",
      "  model = CGCNN(orig_atom_fea_len, nbr_fea_len)\n"
     ]
    }
   ],
   "source": [
    "# Get model\n",
    "structures, _, _ = dataset[0]\n",
    "orig_atom_fea_len = structures[0].shape[-1]\n",
    "nbr_fea_len = structures[1].shape[-1]\n",
    "\n",
    "model = CGCNN(orig_atom_fea_len, nbr_fea_len)\n",
    "\n",
    "#model.cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe0ae6db",
   "metadata": {},
   "source": [
    "### Initialize traning loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b285f1bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = torch.nn.MSELoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31e285aa",
   "metadata": {},
   "source": [
    "### Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a3635ea9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "79it [07:15,  5.51s/it]\n"
     ]
    }
   ],
   "source": [
    "with warnings.catch_warnings():\n",
    "    warnings.simplefilter(\"ignore\")\n",
    "    for epoch in range(1):\n",
    "        for bn, (data, target, _) in tqdm(enumerate(train_loader)):\n",
    "            optimizer.zero_grad()\n",
    "            \n",
    "            input_var = (Variable(data[0]),\n",
    "                         Variable(data[1]),\n",
    "                         data[2],\n",
    "                         data[3])\n",
    "            \n",
    "            # data -> GPU\n",
    "            #input_var = (Variable(data[0].cuda()),\n",
    "            #             Variable(data[1].cuda()),\n",
    "            #             data[2].cuda(),\n",
    "            #             data[3])\n",
    "\n",
    "            pred = model(*input_var)\n",
    "            loss = criterion(pred, target)\n",
    "            #loss = criterion(pred, target.cuda())\n",
    "\n",
    "            loss.backward()\n",
    "            optimizer.step()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3fd1463",
   "metadata": {},
   "source": [
    "### Test the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "841f804a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(model, test_loader, validation=False):\n",
    "    with torch.no_grad():\n",
    "        model.eval()\n",
    "        preds = torch.Tensor([])\n",
    "        targets = torch.Tensor([])\n",
    "        for data, target, _ in test_loader:\n",
    "            input_var = (Variable(data[0]),\n",
    "                         Variable(data[1]),\n",
    "                         data[2],\n",
    "                         data[3])\n",
    "            \n",
    "            # data -> GPU\n",
    "            #input_var = (Variable(data[0].cuda()),\n",
    "            #             Variable(data[1].cuda()),\n",
    "            #             data[2].cuda(),\n",
    "            #             data[3])\n",
    "\n",
    "            pred = model(*input_var)\n",
    "            \n",
    "            preds = torch.cat((preds, pred.cpu().detach()))\n",
    "            targets = torch.cat((targets, target))\n",
    "            \n",
    "        mae = mean_absolute_error(preds, targets)   \n",
    "    set_str = \"VALIDATION\" if(validation) else \"TEST\"\n",
    "    print(\"{0} MAE: {1:.3f}\".format(set_str, mae))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "fed02092",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/clo/miniforge3/envs/auglichem/lib/python3.8/site-packages/pymatgen/io/cif.py:1165: UserWarning: Issues encountered while parsing CIF: Some fractional co-ordinates rounded to ideal values to avoid issues with finite precision.\n",
      "  warnings.warn(\"Issues encountered while parsing CIF: %s\" % \"\\n\".join(self.warnings))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "VALIDATION MAE: 0.280\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/clo/miniforge3/envs/auglichem/lib/python3.8/site-packages/pymatgen/io/cif.py:1165: UserWarning: Issues encountered while parsing CIF: Some fractional co-ordinates rounded to ideal values to avoid issues with finite precision.\n",
      "  warnings.warn(\"Issues encountered while parsing CIF: %s\" % \"\\n\".join(self.warnings))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TEST MAE: 0.274\n"
     ]
    }
   ],
   "source": [
    "evaluate(model, valid_loader, validation=True)\n",
    "evaluate(model, test_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8aa5d9b",
   "metadata": {},
   "source": [
    "### Model saving/loading example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "4a6acbca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save model\n",
    "torch.save(model.state_dict(), \"./saved_models/example_cgcnn\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "8f0b592a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/mh/wfzfv8nd3g7_8w30_pjfbhtr0000gn/T/ipykernel_59279/1379233262.py:6: RuntimeWarning: CrystalDataset._cgcnn must be set to True to use CGCNN properly.\n",
      "  model = CGCNN(orig_atom_fea_len, nbr_fea_len)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "VALIDATION MAE: 6.478\n",
      "TEST MAE: 6.513\n"
     ]
    }
   ],
   "source": [
    "# Instantiate new model and evaluate\n",
    "structures, _, _ = dataset[0]\n",
    "orig_atom_fea_len = structures[0].shape[-1]\n",
    "nbr_fea_len = structures[1].shape[-1]\n",
    "\n",
    "model = CGCNN(orig_atom_fea_len, nbr_fea_len)\n",
    "\n",
    "evaluate(model, valid_loader, validation=True)\n",
    "evaluate(model, test_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "7901e22c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "VALIDATION MAE: 0.280\n",
      "TEST MAE: 0.274\n"
     ]
    }
   ],
   "source": [
    "# Load saved model and evaluate\n",
    "model.load_state_dict(torch.load(\"./saved_models/example_cgcnn\"))\n",
    "evaluate(model, valid_loader, validation=True)\n",
    "evaluate(model, test_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a276d25",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
